import numpy as np
from scipy.stats import norm, t


def random(theta_mean, theta_cov, X, **kwargs):
    """
    Assigns every data point the same score, so that all points are equally preferred.
    Serves as the random acquisition function baseline.
    :param theta_mean: (numpy array) Posterior mean. Unused.
    :param theta_cov: (numpy array) Posterior covariance matrix. Unused.
    :param X: (numpy array) Unlabeled data; only its length is used.
    :param kwargs: Additional arguments. Unused.
    :return: (numpy array) Vector of length len(X) whose entries all equal 1 / len(X).
    """
    num_points = len(X)
    # Divide the array (not a Python scalar) so an empty X simply yields an empty result.
    uniform_scores = np.ones(num_points)
    return uniform_scores / num_points


def get_statistics(X1, X2, theta_cov):
    """
    Computes row-wise inner products that are shared by several acquisition functions.
    Both inputs are promoted to 2-D, and X2 is broadcast against X1 row by row.
    :param X1: (numpy array) One or more unlabeled data points.
    :param X2: (numpy array) Unlabeled data point(s); must not have more rows than X1.
    :param theta_cov: (numpy array) Posterior covariance matrix.
    :return: Tuple (xx, xSx) where xx[i] = <X1[i], X2[i]> and
        xSx[i] = X1[i]^T theta_cov X2[i].
    """
    lhs, rhs = np.atleast_2d(X1), np.atleast_2d(X2)
    assert len(lhs) >= len(rhs)

    # Plain inner product of matching rows.
    inner = (lhs * rhs).sum(axis=-1)

    # Inner product weighted by the posterior covariance.
    projected = lhs @ theta_cov
    weighted_inner = (projected * rhs).sum(axis=-1)

    return inner, weighted_inner


### (multi-class) classification
def class_bald(theta_mean, theta_cov, X, model=None, num_samples=100):
    """
    Computes BALD acquisition function for categorical predictive posterior.
    Note that the function is implemented in PyTorch, not in Numpy as before.

    The score is the entropy of the predictive posterior minus the mean entropy of the
    individual Monte-Carlo predictions. Both terms are estimated from one shared set of
    samples, so the model is only evaluated once per call.
    :param theta_mean: (numpy array) Posterior mean. Unused.
    :param theta_cov: (numpy array) Posterior covariance matrix. Unused.
    :param X: (numpy array) Unlabeled data.
    :param model: (nn.module) Classifier. Must provide `linear(X, num_samples=...)` and
        `_compute_predictive_posterior(...)`. Required despite the `None` default.
    :param num_samples: (int) Number of Monte-Carlo samples to approximate expectations
    :return: BALD acquisition function scores for categorical predictive posterior.
    :raises ValueError: If no model is given.
    """
    # Imported here rather than at module level, so the Numpy-only functions in this
    # module do not need PyTorch.
    import torch

    if model is None:
        raise ValueError("class_bald requires a `model` to draw predictive samples from.")

    def categorical_entropy(logits):
        return torch.distributions.Categorical(logits=logits).entropy()

    # Only a detached Numpy array is returned, so no autograd graph is needed.
    with torch.no_grad():
        # Draw the samples once and reuse them for both terms: a second, independent draw
        # doubles the cost and lets the estimated score become negative.
        # NOTE(review): dim 0 of the output is assumed to index the Monte-Carlo samples.
        sample_logits = model.linear(X, num_samples=num_samples)
        predictive_entropy = categorical_entropy(model._compute_predictive_posterior(sample_logits))
        expected_entropy = torch.mean(categorical_entropy(sample_logits), dim=0)

    return (predictive_entropy - expected_entropy).cpu().numpy()
